{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 读取CSV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "with open(r'..\\fashionAI_keypoints_test\\test.csv') as f:\n",
    "    csv_data = pd.read_csv(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 载入模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import utils\n",
    "import skimage\n",
    "import model as modellib\n",
    "from config import Config\n",
    "\n",
    "PART_INDEX = {'blouse': [0, 1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 13, 14],\n",
    "              'outwear': [0, 1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],\n",
    "              'dress': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 17, 18],\n",
    "              'skirt': [15, 16, 17, 18],\n",
    "              'trousers': [15, 16, 19, 20, 21, 22, 23]}\n",
    "PART_STR = ['neckline_left', 'neckline_right',\n",
    "            'center_front',\n",
    "            'shoulder_left', 'shoulder_right',\n",
    "            'armpit_left', 'armpit_right',\n",
    "            'waistline_left', 'waistline_right',\n",
    "            'cuff_left_in', 'cuff_left_out',\n",
    "            'cuff_right_in', 'cuff_right_out',\n",
    "            'top_hem_left', 'top_hem_right',\n",
    "            'waistband_left', 'waistband_right',\n",
    "            'hemline_left', 'hemline_right',\n",
    "            'crotch',\n",
    "            'bottom_left_in', 'bottom_left_out',\n",
    "            'bottom_right_in', 'bottom_right_out']\n",
    "IMAGE_CATEGORY = ['blouse', 'outwear', 'dress', 'skirt', 'trousers'][3]\n",
    "\n",
    "\n",
    "class FIConfig(Config):\n",
    "    \"\"\"\n",
    "    Configuration for training on the toy shapes dataset.\n",
    "    Derives from the base Config class and overrides values specific\n",
    "    to the toy shapes dataset.\n",
    "    \"\"\"\n",
    "    # Give the configuration a recognizable name\n",
    "    NAME = IMAGE_CATEGORY\n",
    "\n",
    "    # Train on 1 GPU and 8 images per GPU. We can put multiple images on each\n",
    "    # GPU because the images are small. Batch size is 8 (GPUs * images/GPU).\n",
    "    GPU_COUNT = 1\n",
    "    IMAGES_PER_GPU = 1\n",
    "    NUM_KEYPOINTS = len(PART_INDEX[IMAGE_CATEGORY])  # 更改当前训练关键点数目\n",
    "    KEYPOINT_MASK_SHAPE = [56, 56]\n",
    "\n",
    "    # Number of classes (including background)\n",
    "    NUM_CLASSES = 1 + 1\n",
    "\n",
    "    RPN_TRAIN_ANCHORS_PER_IMAGE = 100\n",
    "    VALIDATION_STPES = 100\n",
    "    STEPS_PER_EPOCH = 1000\n",
    "    MINI_MASK_SHAPE = (56, 56)\n",
    "    KEYPOINT_MASK_POOL_SIZE = 7\n",
    "\n",
    "    # Pooled ROIs\n",
    "    POOL_SIZE = 7\n",
    "    MASK_POOL_SIZE = 14\n",
    "    MASK_SHAPE = [28, 28]\n",
    "    WEIGHT_LOSS = True\n",
    "    KEYPOINT_THRESHOLD = 0.005\n",
    "\n",
    "\n",
    "inference_config = FIConfig()\n",
    "\n",
    "# Recreate the model in inference mode\n",
    "model = modellib.MaskRCNN(mode=\"inference\", \n",
    "                          config=inference_config,\n",
    "                          model_dir='./logs_{}'.format(IMAGE_CATEGORY))\n",
    "\n",
    "# Get path to saved weights\n",
    "# print(model.find_last())\n",
    "model.load_weights(model.find_last()[1], by_name=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 预测数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "共有 5576 张图片等待处理... ...\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "\n",
    "col = ['image_id', 'image_category'] + PART_STR                                        # 列标签\n",
    "images_path = csv_data[csv_data.image_category.isin([IMAGE_CATEGORY])].image_id        # 路径\n",
    "kps = np.empty([images_path.shape[0], 26]).astype(str)  # 存储容器\n",
    "\n",
    "batch_size = inference_config.GPU_COUNT * inference_config.IMAGES_PER_GPU     # model calss硬性要求bs的计算方法\n",
    "steps = math.ceil(images_path.index[-1] - images_path.index[0] / batch_size)  # 循环次数\n",
    "\n",
    "print(\"共有 {} 张图片等待处理... ...\".format(images_path.index[-1] - images_path.index[0]))\n",
    "for step in range(int(steps)):\n",
    "    start_index = step * batch_size\n",
    "    if start_index % 100 == 0 and start_index != 0:\n",
    "        print(\"  正在生成第 {} 张关键点结果... ...\".format(start_index))\n",
    "        print(kps[start_index-1])\n",
    "    if step != steps-1:\n",
    "        paths = [os.path.join('../fashionAI_keypoints_test', path)\n",
    "                  for path in images_path[images_path.index[start_index:start_index+batch_size]]]\n",
    "    else:\n",
    "        paths = [os.path.join('../fashionAI_keypoints_test', path)\n",
    "                  for path in images_path[images_path.index[start_index:]]]\n",
    "    images = [skimage.io.imread(path) for path in paths]\n",
    "    logits = model.detect_keypoint(images, verbose=0)\n",
    "\n",
    "#     results = [['_'.join(list(p.astype(str)))\n",
    "#                 for p in logits[i]['keypoints'][0]] for i in range(batch_size)]\n",
    "\n",
    "    results = []\n",
    "    for i in range(batch_size):\n",
    "        try:\n",
    "            for p in logits[i]['keypoints'][0]:\n",
    "                results.extend(['_'.join(list(p.astype(str)))])\n",
    "        except IndexError as e:\n",
    "            print(\"第 {} 轮图片出现异常\".format(step))\n",
    "            result = ['-1_-1_-1' for i in range(len(PART_INDEX[IMAGE_CATEGORY]))]\n",
    "\n",
    "    image_id = [path.split('/')[-1] for path in paths]\n",
    "    image_category = [path.split('/')[-2] for path in paths]\n",
    "\n",
    "    for i, (id_, category) in enumerate(zip(image_id, image_category)):\n",
    "        info_arr = np.array([id_, category] + ['-1_-1_-1' for j in range(24)])\n",
    "        info_arr[np.array(PART_INDEX[IMAGE_CATEGORY]) + 2] = np.array(results)\n",
    "        kps[i+start_index] = info_arr\n",
    "\n",
    "kps_df = pd.DataFrame(kps, columns=col)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 记录"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "kps_df.to_csv(r'./{}.csv'.format(IMAGE_CATEGORY))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
